{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeepSphere using ModelNet40 dataset\n",
    "### Benchmark with Cohen method S2CNN[[1]](http://arxiv.org/abs/1801.10130) and Esteves method[[2]](http://arxiv.org/abs/1711.06721) and other spherical CNNs\n",
    "Multi-class classification of 3D objects, using the interesting property of rotation equivariance.\n",
    "\n",
    "The 3D objects are projected on a unit sphere.\n",
    "Cohen and Esteves use equiangular sampling, while our method uses HEALPix sampling.\n",
    "\n",
    "Several features are collected:\n",
    "* projection ray length (from sphere border to intersection [0, 2])\n",
    "* cos/sin with surface normal\n",
    "* same features using the convex hull of the 3D object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.1 Load libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"  # change to chosen GPU to use, nothing if work on CPU\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import healpy as hp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepsphere import models, experiment_helper, plot, utils\n",
    "from deepsphere.data import LabeledDatasetWithNoise, LabeledDataset\n",
    "import hyperparameters\n",
    "\n",
    "from load_MN40 import plot_healpix_projection, ModelNet40DatasetTF, ModelNet40DatasetCache"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.2 Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Nside = 32\n",
    "exp='norot' # in ['rot', 'norot', 'pert', 'Z']\n",
    "datapath = '../../../data/ModelNet40/' # localisation of the .OFF files\n",
    "proc_path = datapath[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmentation = 1        # number of element per file (1 = no augmentation of dataset)\n",
    "nfeat = 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import trimesh\n",
    "dataset = '/train/'\n",
    "clas = 'airplane'\n",
    "mesh = trimesh.load_mesh(datapath+clas+dataset+clas+\"_0119.off\")\n",
    "mesh.remove_degenerate_faces()\n",
    "mesh.remove_duplicate_faces()\n",
    "mesh.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from load_MN40 import rotmat, rnd_rot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = trimesh.load_mesh(datapath+clas+dataset+clas+\"_0338.off\")\n",
    "mesh.remove_degenerate_faces()\n",
    "mesh.fix_normals()\n",
    "mesh.fill_holes()\n",
    "mesh.remove_duplicate_faces()\n",
    "mesh.remove_infinite_values()\n",
    "mesh.remove_unreferenced_vertices()\n",
    "\n",
    "mesh.apply_translation(-mesh.centroid)\n",
    "r = np.max(np.linalg.norm(mesh.vertices, axis=-1))\n",
    "mesh.apply_scale(1 / r)\n",
    "\n",
    "mesh.apply_transform(rnd_rot(z=0, c=0))\n",
    "\n",
    "r = np.max(np.linalg.norm(mesh.vertices, axis=-1))\n",
    "mesh.apply_scale(0.99 / r)\n",
    "mesh.remove_degenerate_faces()\n",
    "mesh.fix_normals()\n",
    "mesh.fill_holes()\n",
    "mesh.remove_duplicate_faces()\n",
    "mesh.remove_infinite_values()\n",
    "mesh.remove_unreferenced_vertices()\n",
    "mesh.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = '/train/'\n",
    "_class = 'table'\n",
    "plot_healpix_projection(datapath+_class+dataset+_class+\"_0001.off\", 32, rotp = False, rot = (90,0,0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_rot_dataset = ModelNet40DatasetCache(datapath, 'train', nside=Nside, nfeat=nfeat, augmentation=3, nfile=None, \n",
    "                                              experiment='deepsphere_rot_notr')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_rot_tr_dataset = ModelNet40DatasetCache(datapath, 'train', nside=Nside, nfeat=nfeat, augmentation=3, nfile=None, \n",
    "                                           experiment='deepsphere_rot')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_tr_dataset = ModelNet40DatasetCache(datapath, 'train', nside=Nside, nfeat=nfeat, augmentation=3, nfile=None, \n",
    "                                       experiment='deepsphere')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ModelNet40DatasetCache(datapath, 'train', nside=Nside, nfeat=nfeat, augmentation=1, nfile=None, \n",
    "                                          experiment='deepsphere_notr')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_Z_dataset = ModelNet40DatasetCache(datapath, 'train', nside=Nside, nfeat=nfeat, augmentation=3, nfile=None, \n",
    "                                         experiment='deepsphere_Z')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Better to keep validation and testing set in RAM, but not always possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_tr_dataset = ModelNet40DatasetCache(datapath, 'test', nside=Nside, nfeat=nfeat, augmentation=3, nfile=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = ModelNet40DatasetCache(datapath, 'test', nside=Nside, nfeat=nfeat, augmentation=1, nfile=None,\n",
    "                                        experiment='deepsphere_notr')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_rot_tr_dataset = ModelNet40DatasetCache(datapath, 'test', nside=Nside, \n",
    "                                       nfeat=nfeat, experiment='deepsphere_rot', augmentation=3, nfile=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_rot_dataset = ModelNet40DatasetCache(datapath, 'test', nside=Nside, \n",
    "                                       nfeat=nfeat, experiment='deepsphere_rot_notr', augmentation=3, nfile=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_Z_dataset = ModelNet40DatasetCache(datapath, 'test', nside=Nside, \n",
    "                                       nfeat=nfeat, experiment='deepsphere_Z', augmentation=3, nfile=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Try to make a TensorFlow dataset object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment = 'deepsphere'+('_rot' if exp == 'rot' else '')+('_Z' if exp == 'Z' else '')+('_notr' if 'pert' not in exp and exp != 'Z' else '')\n",
    "train_TFDataset = ModelNet40DatasetTF(datapath, 'train', nside=Nside,\n",
    "                                      nfeat=nfeat, augmentation=augmentation, nfile=None, experiment=experiment)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_TFDataset.N"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 compute stats and test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from load_MN40 import compute_mean_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "compute_mean_std(train_dataset, 'train', datapath, Nside)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = train_TFDataset.get_tf_dataset(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "#dataset = tf_dataset_file(datapath, dataset, file_pattern, 32, Nside, augmentation)\n",
    "data_next = dataset.make_one_shot_iterator().get_next()\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "steps = train_TFDataset.N // 32 + 1\n",
    "cm = plt.cm.RdBu_r\n",
    "cm.set_under('w')\n",
    "with tf.Session(config=config) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    for i in tqdm(range(steps)):\n",
    "        data, label = sess.run(data_next)\n",
    "        im1 = data[0,:,0]\n",
    "        cmin = np.nanmin(im1)\n",
    "        cmax = np.nanmax(im1)\n",
    "        hp.orthview(im1, rot=(0,0,0), title=train_TFDataset.classes[label[0]], nest=True, cmap=cm, min=cmin, max=cmax)\n",
    "        plt.figure()\n",
    "        if i > 2:\n",
    "            suffix = train_TFDataset.classes[label[0]]\n",
    "            break\n",
    "#     except tf.errors.OutOfRangeError:\n",
    "#         print(\"Done\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform(data, phi=None, theta=None):\n",
    "    \"\"\"Rotate a batch of HEALPix maps (NESTED ordering) by interpolation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : array of shape (batch_size, npix, nfeat)\n",
    "        Maps to rotate.\n",
    "    phi, theta : float, optional\n",
    "        Rotation angles in radians, passed to `hp.Rotator(rot=[phi, theta])`.\n",
    "        An angle left to None is drawn uniformly at random\n",
    "        (phi in [0, 2*pi), theta in [0, pi)).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    new_data : float64 array with the same shape as `data`.\n",
    "    \"\"\"\n",
    "    batch_size, npix, nfeat = data.shape\n",
    "    # Draw each missing angle on its own, so that an angle supplied alone is\n",
    "    # honoured instead of being silently overwritten by a random one.\n",
    "    if phi is None:\n",
    "        phi = np.random.rand() * 2 * np.pi\n",
    "    if theta is None:\n",
    "        theta = np.random.rand() * np.pi\n",
    "    nside = hp.npix2nside(npix)\n",
    "\n",
    "    # Angular coordinates of every pixel of the non-rotated map\n",
    "    pix_theta, pix_phi = hp.pix2ang(nside, np.arange(npix), nest=True)\n",
    "\n",
    "    # Coordinates of the same pixels under the rotation\n",
    "    rotator = hp.Rotator(deg=False, rot=[phi, theta])\n",
    "    theta_rot, phi_rot = rotator(pix_theta, pix_phi)\n",
    "\n",
    "    # Interpolate every map of the batch at the rotated coordinates\n",
    "    new_data = np.zeros(data.shape)\n",
    "    for b in range(batch_size):\n",
    "        for f in range(nfeat):\n",
    "            new_data[b, :, f] = hp.get_interp_val(data[b, :, f], theta_rot, phi_rot, nest=True)\n",
    "\n",
    "    return new_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_equator(data):\n",
    "    \"\"\"Apply `transform` with phi=0 and theta=pi/2 and cast the result to float32.\"\"\"\n",
    "    return transform(data, 0, np.pi/2).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_shift(data):\n",
    "    \"\"\"\n",
    "    90° rotation around poles (natural Z-axis)\n",
    "\n",
    "    The maps are converted from NESTED to RING ordering, every iso-latitude\n",
    "    ring is cyclically shifted by a quarter of its length, and the result is\n",
    "    converted back to NESTED ordering.\n",
    "\n",
    "    `data` is unpacked as (batch_size, npix, nfeat) and is not modified;\n",
    "    a new array of the same shape and dtype is returned.\n",
    "    \"\"\"\n",
    "    batch_size, npix, nfeat = data.shape\n",
    "    nside = hp.npix2nside(npix)\n",
    "    # Colatitude of every pixel, in RING ordering (pix2ang default)\n",
    "    theta, _ = hp.pix2ang(nside, np.arange(npix))\n",
    "\n",
    "    # The pixel permutation only depends on nside: build it once instead of\n",
    "    # searching the rings again for every (batch, feature) pair.\n",
    "    src, dst = [], []\n",
    "    for t in np.unique(theta):\n",
    "        ligne_ind = np.where(theta == t)[0]\n",
    "        src.append(ligne_ind)\n",
    "        dst.append(np.roll(ligne_ind, len(ligne_ind) // 4))\n",
    "    src = np.concatenate(src)\n",
    "    dst = np.concatenate(dst)\n",
    "\n",
    "    new_data = data.copy()\n",
    "    for b in range(batch_size):\n",
    "        for f in range(nfeat):\n",
    "            ring_map = hp.reorder(data[b, :, f], n2r=True)\n",
    "            # The rings partition the sphere, so every pixel of `shifted` is written\n",
    "            shifted = np.empty_like(ring_map)\n",
    "            shifted[dst] = ring_map[src]\n",
    "            new_data[b, :, f] = hp.reorder(shifted, r2n=True)\n",
    "    return new_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_inverse(data):\n",
    "    \"\"\"\n",
    "    180° rotation around X-axis\n",
    "\n",
    "    Works ring by ring: the maps are converted from NESTED to RING ordering,\n",
    "    every iso-latitude ring is written, in reversed pixel order, onto the ring\n",
    "    taken from the opposite end of the sorted colatitudes, and the result is\n",
    "    converted back to NESTED ordering.\n",
    "    `data` is unpacked as (batch_size, npix, nfeat); a new array is returned.\n",
    "    \"\"\"\n",
    "    batch_size, npix, nfeat = data.shape\n",
    "    data_c = data.copy()\n",
    "    new_data = data.copy()\n",
    "    # Sentinel fill value; presumably only there to make unassigned pixels visible -- TODO confirm\n",
    "    new_data[:] = -10\n",
    "    nside = hp.npix2nside(npix)\n",
    "    # Colatitude of every pixel, in RING ordering (pix2ang default)\n",
    "    theta, _ = hp.pix2ang(nside, range(npix))\n",
    "    theta_u = np.unique(theta)\n",
    "    for b in range(batch_size):\n",
    "        for f in range(nfeat):\n",
    "            data_c[b, :, f] = hp.reorder(data[b, :, f], n2r=True)\n",
    "            # Pair each ring (ascending colatitude) with the ring at the opposite end of theta_u\n",
    "            for i, (t, t_end) in enumerate(zip(theta_u, theta_u[::-1])):\n",
    "                ligne_ind = np.where(theta==t)[0]\n",
    "                ligne_ind_roll = np.where(theta==t_end)[0][::-1]\n",
    "                # Middle half of the rings: every other ring is shifted by one pixel\n",
    "                # (presumably to follow the alternating pixel offset of equatorial rings -- TODO confirm)\n",
    "                if i > len(theta_u)/4 and i < len(theta_u)*3/4:\n",
    "                    ligne_ind_roll = np.roll(ligne_ind_roll, (i+1)%2)\n",
    "                new_data[b,ligne_ind_roll,f] = data_c[b,ligne_ind,f]\n",
    "            new_data[b, :, f] = hp.reorder(new_data[b, :, f], r2n = True)\n",
    "    return new_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hp.orthview(im1, rot=(0,0,0), title=suffix, nest=True, cmap=cm, min=cmin, max=cmax)\n",
    "plt.figure()\n",
    "im2 = transform_shift(im1[np.newaxis,:,np.newaxis])\n",
    "hp.orthview(im2[0,:,0], rot=(0,0,0), title=suffix, nest=True, cmap=cm, min=cmin, max=cmax)\n",
    "plt.figure()\n",
    "im2 = transform_inverse(im2)\n",
    "hp.orthview(im2[0,:,0], rot=(0,0,0), title=suffix, nest=True, cmap=cm, min=cmin, max=cmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 create dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "size = 1 # 32\n",
    "steps = test_dataset.N // size + 1\n",
    "data_iter = test_dataset.iter(size)\n",
    "cm = plt.cm.RdBu_r\n",
    "cm.set_under('w')\n",
    "for i in tqdm(range(steps)):\n",
    "    data, label = next(data_iter)\n",
    "    im1 = data[0,:,0]\n",
    "#     if np.std(im1)>2:\n",
    "#         print(np.std(im1))\n",
    "#     cmin = np.nanmin(im1)\n",
    "#     cmax = np.nanmax(im1)\n",
    "#     hp.orthview(im1, rot=(0,0,0), title=test_dataset.classes[label[0]], nest=True, cmap=cm, min=cmin, max=cmax)\n",
    "#     plt.figure()\n",
    "#     if i > 10:\n",
    "#         break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 1.3 Information"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Shuffle the training dataset and print the classes distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nclass = train_TFDataset.nclass\n",
    "num_elem = train_TFDataset.N\n",
    "print('number of class:',nclass,'\\nnumber of elements:',num_elem)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 Classification using DeepSphere"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = 'MN40_{}_{}feat_{}aug_{}sides'.format(exp, nfeat, augmentation, Nside)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load model with hyperparameters chosen.\n",
    "For each experiment, a new EXP_NAME is chosen, and new hyperparameters are stored.\n",
    "All information is available in 'DeepSphere/Shrec17/experiments.md'.\n",
    "The fastest way to reproduce an experiment is to revert to the commit of the experiment to load the correct files and notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adding a layer in the fully connected can be beneficial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = hyperparameters.get_params_mn40(train_TFDataset.N, EXP_NAME, Nside, nclass, \n",
    "                                                  nfeat_in=nfeat, architecture='CNN')  # get_params_shrec17_optim\n",
    "params[\"tf_dataset\"] = train_TFDataset.get_tf_dataset(params[\"batch_size\"])\n",
    "#params[\"std\"] = [0.001, 0.005, 0.0125, 0.05, 0.15, 0.5] # [0.00002, 0.0002, 0.001, 0.005, 0.0125, 0.05] # best std for nside = 32\n",
    "#params[\"full\"] = [True]*6\n",
    "model = models.deepsphere(**params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shutil.rmtree('../../summaries/{}/'.format(EXP_NAME), ignore_errors=True)\n",
    "shutil.rmtree('../../checkpoints/{}/'.format(EXP_NAME), ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Find a correct learning rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# backup = params.copy()\n",
    "\n",
    "# params, learning_rate = utils.test_learning_rates(params, train_TFDataset.N, 1e-6, 1e-1, num_epochs=20)\n",
    "\n",
    "# shutil.rmtree('summaries/{}/'.format(params['dir_name']), ignore_errors=True)\n",
    "# shutil.rmtree('checkpoints/{}/'.format(params['dir_name']), ignore_errors=True)\n",
    "\n",
    "# model = models.deepsphere(**params)\n",
    "# _, loss_validation, _, _ = model.fit(train_TFDataset, val_dataset, use_tf_dataset=True, cache=True)\n",
    "\n",
    "# params.update(backup)\n",
    "\n",
    "# plt.semilogx(learning_rate, loss_validation, '.-')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shutil.rmtree('summaries/lr_finder/', ignore_errors=True)\n",
    "# shutil.rmtree('checkpoints/lr_finder/', ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "0.9 seems to be a good learning rate for SGD with current parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Train Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"the number of parameters in the model is: {:,}\".format(model.get_nbr_var()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "accuracy_validation, loss_validation, loss_training, t_step, t_batch = model.fit(train_TFDataset, \n",
    "                                                                                 test_dataset, \n",
    "                                                                                 use_tf_dataset=True, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plot.plot_loss(loss_training, loss_validation, t_step, params['eval_frequency'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(train_rot_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(train_rot_tr_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(train_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(train_tr_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(train_Z_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 test network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.evaluate(test_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset.set_transform(transform_shift)\n",
    "print(model.evaluate(test_dataset, None, cache=True))\n",
    "test_dataset.set_transform(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(test_rot_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(test_rot_tr_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(test_tr_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(test_Z_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 exploration of results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions, loss = model.predict(test_dataset, None, cache=True)\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_rot, loss = model.predict(test_rot_dataset, None, cache=True)\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_test = test_dataset.get_labels()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### class attribution for \"flower_pot\" shapes\n",
    "\n",
    "from collections import Counter\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Arrays, so that the samples of one class can be selected with a boolean mask\n",
    "labels_test_arr = np.asarray(labels_test)\n",
    "predictions_arr = np.asarray(predictions)\n",
    "hist = Counter(labels_test)\n",
    "# NaN marks classes without any sample, instead of uninitialised memory\n",
    "accuracy_class = np.full((40,), np.nan)\n",
    "for _class in sorted(hist):\n",
    "    # Select by label value rather than by running offset: this does not\n",
    "    # require the samples to be stored grouped by class.\n",
    "    in_class = labels_test_arr == _class\n",
    "    if _class == 15:\n",
    "        for pred in predictions_arr[in_class]:\n",
    "            print(test_dataset.classes[int(pred)])\n",
    "    accuracy_class[int(_class)] = accuracy_score(labels_test_arr[in_class], predictions_arr[in_class])*100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### accuracy per class\n",
    "plt.plot(accuracy_class, 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset.classes[np.argmax(accuracy_class)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset.classes[np.argmin(accuracy_class)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add rotation perturbations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_3_test = test_rot_dataset.get_labels()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Arrays, so that the samples of one class can be selected with a boolean mask\n",
    "labels_3_test_arr = np.asarray(labels_3_test)\n",
    "predictions_rot_arr = np.asarray(predictions_rot)\n",
    "hist = Counter(labels_3_test)\n",
    "# NaN marks classes without any sample, instead of uninitialised memory\n",
    "accuracy_class = np.full((40,), np.nan)\n",
    "for _class in sorted(hist):\n",
    "    # Select by label value rather than by running offset: this does not\n",
    "    # require the samples to be stored grouped by class.\n",
    "    in_class = labels_3_test_arr == _class\n",
    "    accuracy_class[int(_class)] = accuracy_score(labels_3_test_arr[in_class], predictions_rot_arr[in_class])*100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(accuracy_class, 'o')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy per class seems similar, and indistinguishable classes worsen"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 evolution of logits for a specific class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "files = test_dataset.files\n",
    "files = [file for file in files if 'flower_pot' in file]\n",
    "files = files[:64]\n",
    "batch_1 = np.vstack(test_dataset.get_npy_file(files))\n",
    "batch_1_rot = np.vstack(test_rot_dataset.get_npy_file(files))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset.set_transform(transform_shift)\n",
    "batch_1_shift = np.vstack(test_dataset.get_npy_file(files))\n",
    "test_dataset.set_transform(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs = model.probs(batch_1, 40)\n",
    "probs_shift = model.probs(batch_1_shift, 40)\n",
    "probs_rot = model.probs(batch_1_rot, 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj = 17\n",
    "plt.plot(probs[obj,:], 'o', markersize=10, label='normal')\n",
    "# plt.plot(probs_shift[obj,:], 'o', label='90 shift')\n",
    "plt.plot(probs_rot[3*obj:3*obj+3, :].T, 'o', markersize=4, label = 'rotx')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_max = np.argmax(probs[obj,:])\n",
    "test_dataset.classes[class_max]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_max = np.argmax(probs_rot[3*obj:3*obj+3, :].mean(axis=0))\n",
    "test_dataset.classes[class_max]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cm = plt.cm.RdBu_r\n",
    "cm.set_under('w')\n",
    "cmin = np.min(batch_1[:,:,0])\n",
    "cmax = np.max(batch_1[:,:,0])\n",
    "hp.orthview(batch_1[obj,:,0], rot=(0,0,0), title=files[obj], nest=True, cmap=cm, min=cmin, max=cmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(3):\n",
    "    plt.figure()\n",
    "    hp.orthview(batch_1_rot[3*obj+i,:,0], rot=(0,0,0), title=files[obj], nest=True, cmap=cm, min=cmin, max=cmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Class distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _print_histogram(nclass, labels_train, labels_min=None, ylim=None):\n",
    "    \"\"\"Plot a bar chart of label counts.\n",
    "\n",
    "    With `labels_train` alone, each bar is the number of occurrences of a label.\n",
    "    When `labels_min` is also given, each bar is the absolute difference between\n",
    "    the counts of that label in the two collections; labels with identical\n",
    "    counts are omitted.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    nclass : int\n",
    "        Unused; kept so that existing calls keep working.\n",
    "    labels_train : iterable or None\n",
    "        Labels to count. Nothing is plotted when None.\n",
    "    labels_min : iterable, optional\n",
    "        Second collection of labels to compare against.\n",
    "    ylim : float, optional\n",
    "        Upper limit of the y axis; None lets matplotlib choose.\n",
    "    \"\"\"\n",
    "    if labels_train is None:\n",
    "        return\n",
    "    import matplotlib.pyplot as plt\n",
    "    from collections import Counter\n",
    "    hist_train = Counter(labels_train)\n",
    "    if labels_min is not None:\n",
    "        hist_min = Counter(labels_min)\n",
    "        # Counter subtraction keeps positive counts only, so adding both\n",
    "        # directions gives the absolute difference per label.\n",
    "        hist_train = (hist_train - hist_min) + (hist_min - hist_train)\n",
    "    if hist_train:\n",
    "        labels, values = zip(*hist_train.items())\n",
    "    else:\n",
    "        # Nothing to show (no labels, or both collections have identical counts):\n",
    "        # draw an empty chart instead of failing on the unpacking.\n",
    "        labels, values = (), ()\n",
    "    width = 1\n",
    "    plt.bar(labels, values, width)\n",
    "    plt.title(\"labels distribution\")\n",
    "    plt.ylim(0,ylim)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_print_histogram(40, labels_test)\n",
    "_print_histogram(40, predictions)\n",
    "_print_histogram(40, labels_test, predictions, ylim=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "tot = predictions.shape[0]\n",
    "# Count the true labels once instead of rebuilding the Counter for every class\n",
    "true_counts = Counter(labels_test)\n",
    "# Per class: number of predictions minus number of true samples\n",
    "hist = Counter(predictions)\n",
    "hist.subtract(true_counts)\n",
    "p_tot = 0\n",
    "for _class, nb in hist.most_common():\n",
    "    n_true = true_counts[_class]\n",
    "    # A class that is predicted but absent from the labels has no relative error\n",
    "    percent = 100*nb/n_true if n_true else float('nan')\n",
    "    p_tot += percent\n",
    "    # Class names come from test_dataset, the dataset that produced `predictions`\n",
    "    print(\"{:2.0f}\".format(_class), test_dataset.classes[int(_class)], \"{:.2f}\".format(percent))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `labels` is keyword-only in recent scikit-learn releases; passing it positionally raises a TypeError\n",
    "plt.imshow(confusion_matrix(labels_test, predictions, labels=np.arange(40)), cmap = plt.cm.gist_heat_r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
